[feat][plugin] make ATOM MLA attention work for vLLM #265

Open

XiaobingSuper wants to merge 13 commits into ROCm:main from XiaobingSuper:xiaobing/oot_kimi

Conversation


@XiaobingSuper XiaobingSuper commented Mar 4, 2026

Motivation

Following #126, this PR makes ATOM MLA attention work in vLLM plugin mode. Note: sparse MLA is not supported yet and will be implemented in the next step.

Technical Details

The design details can be found in #126.

Test Plan

This PR tests the Kimi-K2-Thinking-MXFP4 model with TP=4 on MI355:

```shell
export SAFETENSORS_FAST_GPU=1
export VLLM_ROCM_USE_AITER=1
export VLLM_RPC_TIMEOUT=1800000

export VLLM_CACHE_ROOT=/root/.cache/vllm
export TORCHINDUCTOR_CACHE_DIR=/root/.cache/inductor
export HIP_VISIBLE_DEVICES=0,1,2,3
# quick allreduce
export AITER_QUICK_REDUCE_QUANTIZATION=INT4
export ATOM_PROFILER_MORE=1

export VLLM_TORCH_PROFILER_RECORD_SHAPES=1

model_path=Kimi-K2-Thinking-MXFP4
vllm serve $model_path \
    --host localhost \
    --port 8001 \
    --tensor-parallel-size 4 \
    --enable-expert-parallel \
    --trust-remote-code \
    --disable-log-requests \
    --gpu_memory_utilization 0.9 \
    --compilation-config '{"cudagraph_mode": "FULL_AND_PIECEWISE"}' \
    --kv-cache-dtype fp8 \
    --max-num-batched-tokens 18432 \
    --max-model-len 16384 \
    --no-enable-prefix-caching
```

Test Result

gsm8k result:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.9371|±  |0.0067|
|     |       |strict-match    |     3|exact_match|↑  |0.9363|±  |0.0067|
  

Submission Checklist

Copilot AI review requested due to automatic review settings March 4, 2026 11:49
Contributor

Copilot AI left a comment
Pull request overview

Adds vLLM plugin-mode support for ATOM’s MLA attention path (non-sparse), including backend selection, metadata plumbing, and DeepSeek V3 model registration/loading so MLA can run end-to-end under vLLM.

Changes:

  • Route vLLM’s use_mla attention selection to an ATOM MLA backend and add MLA-specific plugin-mode metadata builders.
  • Implement plugin-mode MLA forward/prefill/decode logic (including positions capture for graph mode).
  • Register DeepSeek V3 as a supported vLLM plugin model and add a plugin-mode load_weights implementation.
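The backend routing in the first bullet can be sketched roughly as follows. This is a minimal sketch modeled on the description above; the class and backend path names are assumptions, not the actual `atom/plugin/vllm/platform.py` code.

```python
# Sketch of MLA backend selection; names are hypothetical stand-ins,
# not the actual atom/plugin/vllm/platform.py implementation.
from dataclasses import dataclass

@dataclass
class AttnSelectorConfig:
    use_mla: bool

# Hypothetical backend class paths for illustration only.
ATOM_MLA_BACKEND = "atom.model_ops.attentions.aiter_mla.AiterMLABackend"
ATOM_MHA_BACKEND = "atom.model_ops.attentions.aiter_attention.AiterAttentionBackend"

def get_attn_backend_cls(cfg: AttnSelectorConfig) -> str:
    # Route vLLM's attention selection to the ATOM MLA backend when MLA
    # is requested; otherwise fall back to the regular ATOM backend.
    return ATOM_MLA_BACKEND if cfg.use_mla else ATOM_MHA_BACKEND

print(get_attn_backend_cls(AttnSelectorConfig(use_mla=True)))
```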

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 9 comments.

| File | Description |
|------|-------------|
| atom/utils/backends.py | Extends compilation-cache hashing to ignore `<frozen os>` traced "files". |
| atom/plugin/vllm/register.py | Patches vLLM process_weights_after_loading for Attention/MLAAttention. |
| atom/plugin/vllm/platform.py | Selects ATOM MLA backend when attn_selector_config.use_mla is true. |
| atom/plugin/vllm/model_wrapper.py | Copies positions into a static buffer for graph-mode MLA correctness. |
| atom/plugin/attention_mla.py | New: plugin-mode MLAAttention implementation helpers (prefill/decode/DCP). |
| atom/plugin/attention.py | Adds MLA plugin-mode metadata builders + backend wiring; renames plugin metadata class. |
| atom/models/deepseek_v2.py | Adds DeepSeek V3 support + plugin-mode load_weights. |
| atom/model_ops/utils.py | Removes duplicate per_tensor_dequantize implementation (keeps the canonical one). |
| atom/model_ops/paged_attention.py | Integrates vLLM MLAAttention usage and allocates a shared positions buffer. |
| atom/model_ops/linear.py | Ensures activation tensor is contiguous before quantizer .view() calls. |
| atom/model_ops/base_attention.py | Adjusts MLA unified-attn path to apply o_proj outside MLA impl. |
| atom/model_ops/attentions/aiter_mla.py | Decorates MLA backend/builder for plugin mode; builder init adjustments. |
| atom/model_ops/attentions/aiter_attention.py | Removes unused import. |
| atom/model_ops/attention_mla.py | Adds plugin-mode hooks/decorator and splits v_up and o_proj responsibilities. |


```diff
+# quant_func will call view, so we need to call contiguous to avoid view error
 x, x_scale = quant_func(
-    x,
+    x.contiguous(),
```
Author
This is required for the deepseek-r1 model, where x is a sliced tensor that cannot be viewed.
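The underlying issue can be reproduced with NumPy, used here only to illustrate why a sliced tensor breaks zero-copy reshaping; the actual code deals with torch tensors, where `.view()` raises outright on non-contiguous input:

```python
import numpy as np

x_full = np.arange(12, dtype=np.float32).reshape(3, 4)
x = x_full[:, :2]                  # sliced tensor: row data is no longer adjacent in memory

print(x.flags["C_CONTIGUOUS"])     # False -- a zero-copy view/reshape is impossible here
x_contig = np.ascontiguousarray(x)  # analogous to torch's x.contiguous(); makes a copy
print(x_contig.flags["C_CONTIGUOUS"])  # True
```

The copy is exactly the overhead the reviewer objects to below, which is why the `.contiguous()` call ends up on the plugin side only.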

Collaborator
We will do something else to avoid contiguous, since it introduces a memory copy here; all our quant ops should already support non-contiguous tensors... did we hit an issue here?

Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Collaborator
Then maybe do contiguous at that place; I don't want to lose any perf.

Author
Yes, I updated the code to do contiguous on the plugin side.

Author

DeepSeek-R1-0528 with TP=8 has also been tested:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     3|exact_match|↑  |0.9424|±  |0.0064|
|     |       |strict-match    |     3|exact_match|↑  |0.9363|±  |0.0067|

Copilot AI review requested due to automatic review settings March 4, 2026 13:00
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.



Collaborator

@ChuanLi1101 ChuanLi1101 left a comment

Left my comment FYI.

Copilot AI review requested due to automatic review settings March 5, 2026 05:49
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 4 comments.



Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 5 comments.



ChuanLi1101 previously approved these changes Mar 5, 2026

Collaborator

@ChuanLi1101 ChuanLi1101 left a comment

LGTM, thanks for the quick turnaround.

Copilot AI review requested due to automatic review settings March 6, 2026 07:27
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 1 comment.



ZhangLirong-amd previously approved these changes Mar 6, 2026
```diff
 # dummy run: skip real attention and return
 output_shape = list(q.shape)
-output_shape[-1] = 7168
+output_shape[-1] = self.num_heads * self.v_head_dim
```
Collaborator

self.num_heads * self.v_head_dim looks like it is not equal to 7168 for deepseek.

Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

This is o_proj's input; see the atom path:
[image]
and the plugin path:
[image]

The reason is that vLLM does o_proj outside of the attention backend.
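As a concrete check of the reviewer's point, using DeepSeek-V3's commonly cited config values (hidden_size 7168, 128 attention heads, v_head_dim 128; these values are assumptions, not taken from this PR):

```python
# Illustrative shape check; config values assumed from DeepSeek-V3.
hidden_size = 7168          # model hidden dim, the old hard-coded output width
num_heads = 128             # total attention heads
v_head_dim = 128            # per-head value dim after the v_up projection

o_proj_input_dim = num_heads * v_head_dim   # width of o_proj's INPUT
print(o_proj_input_dim)                     # 16384, not 7168
```

So when o_proj lives outside the attention backend, the attention output width is num_heads * v_head_dim rather than hidden_size, which is why the hard-coded 7168 was wrong.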

Collaborator

The plugin path is also in our repo... then why do we have to move o_proj out of attn?

Author

This change is for the fallback path, i.e., plugin mode but using the vLLM attn backend. Because we use vLLM's MLAAttention class (self.attn here), the forward path doesn't have o_proj; see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L640, which only does the attention compute.
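A minimal runnable sketch of this split, with NumPy standing in for real tensors and all names hypothetical: the attention impl returns the flattened per-head value outputs, and the caller applies o_proj afterwards.

```python
import numpy as np

rng = np.random.default_rng(0)
num_tokens, num_heads, v_head_dim, hidden = 4, 8, 16, 32

def mla_attention(q):
    # Stand-in for an attention forward that, like vLLM's MLAAttention,
    # returns per-head outputs flattened to [num_tokens, num_heads * v_head_dim]
    # with NO o_proj applied inside.
    return rng.standard_normal((q.shape[0], num_heads * v_head_dim))

w_o = rng.standard_normal((num_heads * v_head_dim, hidden))  # o_proj weight

def forward(q):
    attn_out = mla_attention(q)   # [num_tokens, num_heads * v_head_dim]
    return attn_out @ w_o         # caller applies o_proj -> [num_tokens, hidden]

out = forward(np.zeros((num_tokens, 1)))
print(out.shape)                  # (4, 32)
```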


Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 12 out of 12 changed files in this pull request and generated 3 comments.



Copilot AI review requested due to automatic review settings March 6, 2026 15:42
```diff
     **kwargs,
 )
 
+impl_args["head_size" if self.use_mla else "head_dim"] = head_dim
```
Collaborator

This also comes from vllm? I would like us to always use head_dim.

Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Yes, vLLM uses head_size; see https://github.com/vllm-project/vllm/blob/v0.15.1/vllm/attention/layer.py#L579. Before this PR, 6c40248 also used head_size.
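The naming split under discussion can be sketched like this (a hypothetical helper; the real impl classes are stand-ins, and only the kwarg name differs between the two paths):

```python
def build_impl_args(head_dim, use_mla, **extra):
    # vLLM's MLA impl expects `head_size`, while the MHA impl expects
    # `head_dim` -- same value, different kwarg name.
    impl_args = dict(extra)
    impl_args["head_size" if use_mla else "head_dim"] = head_dim
    return impl_args

print(build_impl_args(128, True))    # {'head_size': 128}
print(build_impl_args(128, False))   # {'head_dim': 128}
```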

```diff
     self.layer_num = layer_num
 
-    def process_weights_after_loading(self):
+    def process_weights_after_loading(self, act_dtype: Optional[torch.dtype] = None):
```
Collaborator

@zejunchen-zejun do we need to add this arg?

Author

@XiaobingSuper XiaobingSuper Mar 6, 2026

Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.



Comment on lines 190 to +208

```diff
@@ -153,7 +204,8 @@ def __init__(
     k_norm=k_norm,
     **kwargs,
 )
 
+impl_args["head_size" if self.use_mla else "head_dim"] = head_dim
 self.impl = impl_cls(**impl_args)
```
Copilot AI Mar 6, 2026

When use_mla is True, impl_cls is atom.model_ops.attention_mla.MLAAttention which now expects head_size (not head_dim). This code always includes head_dim in impl_args and then also adds head_size, so MLAAttention will receive an unexpected head_dim kwarg and raise at construction time. Build impl_args conditionally (only pass head_dim for MHA, and only pass head_size for MLA), or remove the unconditional head_dim entry before instantiating the MLA impl.

7 participants